[WebGPU EP] Support GroupQueryAttention #22658
Conversation
const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
const bool has_present_key = output_count > 1 && past_key;
const bool has_attention_bias = attention_bias != nullptr;
const int tile_size = 12;
Check warning (Code scanning / PREfast): The const variable 'tile_size' can be computed at compile-time. Consider using constexpr (con.5).
fixed
You can commit the suggested changes from lintrunner.
present_value, parameters, context, seqlen_k, total_seqlen_tensor);
}
TensorShape k_new_shape(k_new_dims);
Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
This line causes a segfault with these models: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile, as they have GQA nodes that do not have the optional `key` and `value` inputs, so the `Tensor*` is a nullptr.
onnxruntime/onnxruntime/core/graph/contrib_ops/bert_defs.cc, lines 1090 to 1111 (at c64459f):
.Input(1,
       "key",
       "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ",
       "T",
       OpSchema::Optional)
.Input(2,
       "value",
       "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)",
       "T",
       OpSchema::Optional)
.Input(3,
       "past_key",
       "past state key with support for format BNSH. When past_key uses same tensor as present_key"
       "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.",
       "T",
       OpSchema::Optional)
.Input(4,
       "past_value",
       "past state value with support for format BNSH. When past_value uses same tensor as present_value"
       "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.",
       "T",
       OpSchema::Optional)
WebgpuAttentionParameters is not copying the value of is_packed_qkv_ from GroupQueryAttentionParameters.
Need to correctly initialize is_packed_qkv_.
struct WebgpuAttentionParameters {
  WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false),
                                                              batch_size_(parameters.batch_size),
What's the reason WebGPU needs a parameters struct that combines AttentionParameters and GroupQueryAttentionParameters? Feels a little confusing to merge those and wondering why it's necessary if we don't need to do that for other EPs that implement these operators.
I am trying to avoid code duplication: I refactored the code into a shared attention implementation used by both GQA and MHA. The CPU version keeps a separate GQA implementation. group_query_attention_helper::CheckInputs() and AttentionBase::CheckInputs() output different structs, GroupQueryAttentionParameters and AttentionParameters respectively; the WebGPU parameters struct is a union of these two structs.
Force-pushed from 514217f to d49ecb4.
You can commit the suggested changes from lintrunner.
shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
  shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n";
shader.MainFunctionBody() << "  if (n + local_id.y < past_sequence_length) {\n"
                             "    tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
                             "  } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
                             "    tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
                             "  }\n";
} else {
Suggested change (lintrunner):
shader.MainFunctionBody() << "  if (n + local_id.y < past_sequence_length) {\n"
                             "    tileK[idx] = "
                          << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
                             "  } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
                             "    tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
                             "  }\n";
} else {
const Tensor* seqlen_k) {
  const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
  const bool has_present_value = output_count > 1 && past_value != nullptr;
.TypeConstraint("T", WebGpuSupportedFloatTypes())
    .MayInplace(3, 1)
    .MayInplace(4, 2)
    .InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);
You can commit the suggested changes from lintrunner.
<< "    tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
   "  } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
   "    tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
   "  }\n";
} else {
Force-pushed from 4a072b5 to 5f1fdae.
You can commit the suggested changes from lintrunner.
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_;
You can commit the suggested changes from lintrunner.
const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components -1) / components;
if (total_sequence_length_comp < work_group_size) {
Suggested change (lintrunner):
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
if (total_sequence_length_comp < work_group_size) {
This reverts commit 15c96b3.